

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List

# ---------------------------------------------------------------------------- #
# Clubbed-emotion mapping helpers                                              #
# ---------------------------------------------------------------------------- #

# Fine-grained → coarse buckets.
# Each list below holds the fine-grained labels that belong to one coarse bucket.
positive_emotions = [
    "active", "amazed", "amused", "calm", "cheerful",
    "confident", "conscious", "creative", "eager", "educated",
    "emotional", "empathetic", "fashionable", "feminine", "grateful",
    "inspired", "loving", "manly", "persuaded", "proud",
    "thrifty", "youthful", "excited", "relieved", "intrigued",
    "elated", "mesmerized", "perplexed", "overwhelmed",
]

negative_emotions = [
    "afraid",
    "disturbed",
    "jealous",
    "pessimistic",
    "sad",
]
fear_emotions = [
    "afraid",
    "alarmed",
    "alert",
]
anger_emotions = ["angry"]
disgust_emotions = ["disturbed"]

# Coarse bucket name -> member labels.  Insertion order matters: `to_coarse`
# returns the first bucket (in this order) that contains the label.
# NOTE(review): "afraid" and "disturbed" are listed in two buckets each, so
# with first-match lookup they resolve to "Negative Emotions" and the
# "Disgust" bucket is never returned by `to_coarse` — confirm this is intended.
emotions_dict: Dict[str, List[str]] = {
    "Positive Emotions": positive_emotions,
    "Negative Emotions": negative_emotions,
    "Fear": fear_emotions,
    "Anger": anger_emotions,
    "Disgust": disgust_emotions,
}


def to_coarse(emotion: str) -> str:
    """Return the coarse bucket name for a fine-grained emotion label.

    The label is lower-cased and stripped of surrounding whitespace, then
    looked up in every bucket of ``emotions_dict`` in the dictionary's
    iteration order.  The first bucket containing the label wins; if no
    bucket contains it, ``"Unclear"`` is returned.
    """
    key = emotion.lower().strip()
    return next(
        (name for name, labels in emotions_dict.items() if key in labels),
        "Unclear",
    )

# Coarse label order (must match annotation indices)
EMOTION_ORDER: List[str] = [
    "Positive Emotions", "Negative Emotions",
    "Fear", "Anger", "Disgust",
    "Unclear",
]

# 1-based option index -> coarse label, following EMOTION_ORDER.
OPTION_TO_EMOTION: Dict[int, str] = dict(enumerate(EMOTION_ORDER, start=1))

# Legacy fine-grained order (1-based) used by older annotation files
FINE_EMOTION_ORDER: List[str] = [
    "active", "afraid", "alarmed", "alert", "amazed", "amused",
    "angry", "calm", "cheerful", "confident", "conscious", "creative",
    "disturbed", "eager", "educated", "emotional", "empathetic", "fashionable",
    "feminine", "grateful", "inspired", "jealous", "loving", "manly",
    "persuaded", "pessimistic", "proud", "sad", "thrifty", "youthful",
]

# 1-based option index -> fine-grained label, following FINE_EMOTION_ORDER.
FINE_OPTION_TO_EMOTION = dict(enumerate(FINE_EMOTION_ORDER, start=1))


# ---------------------------------------------------------------------------- #
# Utility functions                                                             #
# ---------------------------------------------------------------------------- #


def load_ground_truth(path: Path) -> Dict[str, str]:
    """Load ground-truth mapping *video_id* -> *coarse emotion label*.

    The annotation file must contain a JSON object whose values are 1-based
    option indices (integers, or strings/floats holding an integer value).
    An index is first looked up in ``OPTION_TO_EMOTION``; only if it is absent
    there is it looked up in ``FINE_OPTION_TO_EMOTION`` and converted with
    ``to_coarse``.  Videos whose resulting label is ``"Unclear"`` are dropped.

    Parameters
    ----------
    path : Path
        Path to *annotation.json* file.

    Returns
    -------
    Dict[str, str]
        Dictionary mapping video id to coarse emotion label, without any
        ``"Unclear"`` entries.

    Raises
    ------
    ValueError
        If the top-level JSON value is not an object, or an option value
        cannot be interpreted as an integer.
    KeyError
        If an option index is present in neither mapping.
    """
    with path.open("r", encoding="utf-8") as f:
        raw = json.load(f)

    if not isinstance(raw, dict):
        raise ValueError(
            "Ground truth must be a JSON object mapping video ids to option "
            f"indices, got {type(raw).__name__}."
        )

    gt: Dict[str, str] = {}
    for vid, opt in raw.items():
        # bool is a subclass of int (JSON true/false would pass as 1/0) and
        # int() silently truncates non-integral floats; reject both explicitly.
        if isinstance(opt, bool) or (isinstance(opt, float) and not opt.is_integer()):
            raise ValueError(f"Invalid option value {opt!r} for video {vid!r} in ground truth.")
        try:
            # Sometimes numbers are encoded as strings; int() handles those too.
            opt_int = int(opt)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Invalid option value {opt!r} for video {vid!r} in ground truth."
            ) from err

        # NOTE(review): the coarse mapping is consulted first, so an index that
        # exists in both mappings is always read as a coarse option; the
        # fine-grained mapping is only reached for indices missing from it.
        emo = OPTION_TO_EMOTION.get(opt_int)
        if emo is None:
            fine = FINE_OPTION_TO_EMOTION.get(opt_int)
            if fine is None:
                raise KeyError(
                    f"No emotion mapping found for option {opt_int} (video id: {vid})."
                )
            emo = to_coarse(fine)
        if emo != "Unclear":
            gt[vid] = emo  # keep only clear coarse labels
    return gt


def load_predictions(pred_dir: Path) -> List[tuple[str, str]]:
    """Load predictions from **all** *.json* files inside *pred_dir*.

    Files are visited in sorted path order so the order of the returned
    records is reproducible.  Each file may hold a single JSON object or a
    list of objects.  For every object the video id is taken from
    ``"video_id"`` (falling back to the file stem) and the fine-grained
    emotion from ``"final_topic"`` or, failing that, ``"predicted_topic"``.
    The emotion is mapped through ``to_coarse``; records mapping to
    ``"Unclear"`` are dropped.

    Files that cannot be read or parsed, and records that are malformed
    (not an object, missing or non-string emotion), are skipped with a
    ``[WARN]`` message instead of aborting the whole load.

    Returns a list of *(video_id, emotion)* pairs **including duplicates** so that
    we can compute accuracy both with and without deduplication.
    """
    records_all: List[tuple[str, str]] = []

    for json_path in sorted(pred_dir.glob("*.json")):
        try:
            with json_path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"[WARN] Failed to parse {json_path.name}: {e}")
            continue

        # Some prediction files are a single object, others are a list of objects.
        recs_in_file = data if isinstance(data, list) else [data]

        for rec in recs_in_file:
            if not isinstance(rec, dict):
                print(f"[WARN] Unexpected record type in {json_path.name}: {type(rec).__name__}; skipping.")
                continue

            video_id = rec.get("video_id") or json_path.stem  # fallback to file name

            # accept either key name (fine-grained)
            fine_emotion = rec.get("final_topic") or rec.get("predicted_topic")
            if not fine_emotion:
                print(f"[WARN] Missing emotion prediction for video {video_id} in {json_path.name}; skipping.")
                continue
            if not isinstance(fine_emotion, str):
                # to_coarse() calls str methods; a number/list here would crash it.
                print(
                    f"[WARN] Non-string emotion prediction {fine_emotion!r} for video {video_id} "
                    f"in {json_path.name}; skipping."
                )
                continue

            coarse = to_coarse(fine_emotion)
            if coarse == "Unclear":
                continue  # skip ambiguous
            # JSON object keys are always strings, so ground-truth ids are
            # strings; normalise numeric ids so they can be matched.
            records_all.append((str(video_id), coarse))

    return records_all


# ------------------------------------------------------------------------- #
# Accuracy helpers
# ------------------------------------------------------------------------- #


def compute_accuracy_records(records: List[tuple[str, str]], gt: Dict[str, str]) -> tuple[int, int]:
    """Compute (correct, total) for a list of prediction records (may contain duplicates).

    Parameters
    ----------
    records : List[tuple[str, str]]
        *(video_id, predicted_emotion)* pairs; the same video id may appear
        more than once and every occurrence is scored.
    gt : Dict[str, str]
        Ground-truth mapping *video_id* -> *emotion*.

    Returns
    -------
    tuple[int, int]
        ``(correct, total)`` where *total* counts only the records whose
        video id has a ground-truth label, and *correct* counts those whose
        prediction equals that label.
    """
    correct = 0
    total = 0
    for vid, emotion_pred in records:
        gt_label = gt.get(vid)
        if gt_label is None:
            # No ground truth for this id (unknown or filtered out): the record
            # cannot be scored, so it must not inflate the denominator either.
            continue
        total += 1
        if gt_label == emotion_pred:
            correct += 1
    return correct, total


def compute_accuracy_unique(pred_unique: Dict[str, str], gt: Dict[str, str]) -> tuple[int, int]:
    """Score a one-prediction-per-video mapping against the ground truth.

    Only video ids that have a ground-truth label are scored.  Returns
    ``(correct, total)``: *total* is the number of scored ids and *correct*
    the number of those whose prediction equals the ground-truth label.
    """
    # Pair each prediction with its label, dropping ids that are unknown or
    # unclear in the ground truth.
    scored = [
        (predicted, gt[vid])
        for vid, predicted in pred_unique.items()
        if gt.get(vid) is not None
    ]
    hits = sum(1 for predicted, expected in scored if predicted == expected)
    return hits, len(scored)


def main() -> None:
    """Evaluate every prediction file in a directory against the ground truth.

    Command-line arguments
    ----------------------
    --pred_dir    Directory containing prediction ``*.json`` files (required).
    --annot_file  Path to the ground-truth annotation JSON file (required).
    --output      Report file to write (default: ``np_metric.txt``).

    Each ``*.json`` file directly inside ``--pred_dir`` is scored on its own,
    in sorted path order.  Two accuracies are reported per file: one over all
    prediction records (duplicates included) and one over unique video ids,
    where the last prediction for an id wins.  Results are printed to the
    console and the report file is written once, after all files have been
    scored, with one ``[<file name>]`` header followed by that file's two
    accuracy lines.  If the directory holds no ``*.json`` files a warning is
    printed and no report is written.

    Raises
    ------
    NotADirectoryError
        If ``--pred_dir`` is not an existing directory.
    FileNotFoundError
        If ``--annot_file`` is not an existing file.
    """
    parser = argparse.ArgumentParser(description="Evaluate emotion prediction accuracy.")
    parser.add_argument("--pred_dir", type=str, required=True, help="Directory containing prediction JSON files.")
    parser.add_argument("--annot_file", type=str, required=True, help="Path to emotion_annotation.json ground-truth file.")
    # %(default)s keeps the help text in sync with the actual default value.
    parser.add_argument("--output", type=str, default="np_metric.txt", help="File to write accuracy to (default: %(default)s)")
    args = parser.parse_args()

    pred_dir = Path(args.pred_dir)
    annot_path = Path(args.annot_file)
    out_path = Path(args.output)

    if not pred_dir.is_dir():
        raise NotADirectoryError(f"Prediction directory not found: {pred_dir}")
    if not annot_path.is_file():
        raise FileNotFoundError(f"Annotation file not found: {annot_path}")

    # Load data
    gt_map = load_ground_truth(annot_path)

    # Evaluate each JSON file separately
    def load_predictions_file(path: Path) -> List[tuple[str, str]]:
        """Load (video_id, coarse emotion) pairs only from the given JSON file path."""
        try:
            with path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"[WARN] Could not read {path.name}: {e}")
            return []

        recs = data if isinstance(data, list) else [data]
        pairs: List[tuple[str, str]] = []
        for rec in recs:
            if not isinstance(rec, dict):
                continue
            vid = rec.get("video_id") or path.stem
            fine = rec.get("final_topic") or rec.get("predicted_topic")
            # to_coarse() calls str methods, so non-string values are skipped
            # rather than allowed to crash the run.
            if not fine or not isinstance(fine, str):
                continue
            coarse = to_coarse(fine)
            if coarse == "Unclear":
                continue
            # Ground-truth ids come from JSON object keys and are strings.
            pairs.append((str(vid), coarse))
        return pairs

    json_paths = sorted(pred_dir.glob("*.json"))
    if not json_paths:
        print(f"[WARN] No prediction JSON files found in {pred_dir}")
        return

    # Collected across all files and written once at the end; reopening the
    # output in "w" mode per file would keep only the last file's metrics.
    report_lines: List[str] = []

    for json_path in json_paths:
        pred_records_file = load_predictions_file(json_path)

        # Compute accuracy including duplicate predictions
        correct_with_dup, total_with_dup = compute_accuracy_records(pred_records_file, gt_map)
        accuracy_with_dup = (correct_with_dup / total_with_dup) if total_with_dup else 0.0

        # Build unique prediction map (last prediction wins) & log overrides
        pred_unique: Dict[str, str] = {}
        for vid, emo in pred_records_file:
            if vid in pred_unique and pred_unique[vid] != emo:
                print(
                    f"[INFO] Duplicate prediction for {vid} – overriding '{pred_unique[vid]}' with '{emo}'."
                )
            pred_unique[vid] = emo

        # Compute accuracy on unique predictions
        correct_unique, total_unique = compute_accuracy_unique(pred_unique, gt_map)
        accuracy_unique = (correct_unique / total_unique) if total_unique else 0.0

        # Prepare output
        lines = [
            f"Accuracy (with duplicates): {accuracy_with_dup:.4f} ({correct_with_dup}/{total_with_dup})",
            f"Accuracy (unique video IDs): {accuracy_unique:.4f} ({correct_unique}/{total_unique})",
        ]

        # Print to console
        for line in lines:
            print(line)

        # Queue for the report, labelled with the file the metrics belong to
        report_lines.append(f"[{json_path.name}]")
        report_lines.extend(lines)

        print(f"{json_path.name}: dup {accuracy_with_dup:.4f} | unique {accuracy_unique:.4f}")

    # Persist to file
    with out_path.open("w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in report_lines)


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
